package com.luo.d3s.ext.data.permission.interceptor;

import com.luo.d3s.ext.data.permission.anno.DataPermission;
import com.luo.d3s.ext.data.permission.config.DataPermissionProps;
import com.luo.d3s.ext.data.permission.context.DpUserContextHolder;
import com.luo.d3s.ext.data.permission.dto.context.base.BaseDpDto;
import com.luo.d3s.ext.data.permission.dto.context.base.BaseUserDto;
import com.luo.d3s.ext.data.permission.dto.parse.SqlParseParamDto;
import com.luo.d3s.ext.data.permission.enums.DpOpEnum;
import com.luo.d3s.ext.data.permission.enums.DpTypes;
import com.luo.d3s.ext.data.permission.exception.DataPermissionException;
import com.luo.d3s.ext.data.permission.util.DpUtils;
import com.luo.d3s.ext.data.permission.util.SqlConditionUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.AnnotatedElementUtils;

import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Stream;


/**
 * Data permission - MyBatis {@link Interceptor}.
 * <p>
 * Intercepts {@code Executor.update(...)} and both {@code Executor.query(...)} overloads. When the executed
 * mapper method (or its mapper interface, via a class-level {@link DataPermission} whose {@code methodName}
 * matches) carries {@link DataPermission}, the SQL is rewritten to include the data permission condition of
 * the current user obtained from {@link DpUserContextHolder}.
 *
 * @author luohq
 * @date 2023-06-18
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class DataPermissionInterceptor implements Interceptor {

    private static final Logger log = LoggerFactory.getLogger(DataPermissionInterceptor.class);

    /**
     * Number of arguments of the Executor.query(...) overload that carries CacheKey and BoundSql.
     */
    private static final int EXECUTOR_QUERY_CACHE_ARGS_COUNT = 6;

    /**
     * Index of the RowBounds argument in both Executor.query(...) overloads.
     */
    private static final int ROW_BOUNDS_ARG_INDEX = 2;

    /**
     * Index of the CacheKey argument in the 6-argument Executor.query(...) overload.
     */
    private static final int CACHE_KEY_ARG_INDEX = 4;

    /**
     * Index of the BoundSql argument in the 6-argument Executor.query(...) overload.
     */
    private static final int BOUND_SQL_ARG_INDEX = 5;

    /**
     * Placeholder in mapper SQL that is replaced by the data permission condition.
     */
    private static final String DATA_PERMISSION_CONDITION_PLACEHOLDER = "{DATA_PERMISSION_CONDITION}";

    /**
     * Regex matching tab / newline / carriage-return characters.
     */
    private static final String BLANK_CHAR_REGEX = "[\\t\\n\\r]";

    /**
     * Pre-compiled {@link #BLANK_CHAR_REGEX} (avoids recompiling the regex on every intercepted statement).
     */
    private static final Pattern BLANK_CHAR_PATTERN = Pattern.compile(BLANK_CHAR_REGEX);

    /**
     * Single space used to replace blank characters.
     */
    private static final String SPACE_STRING = " ";

    /**
     * Suffix appended by PageHelper to the statement id of its auto-generated count query.
     */
    private static final String PAGE_HELPER_COUNT_SUFFIX = "_COUNT";

    /**
     * Cache: mapperMethodId -> resolved @DataPermission annotation.
     */
    private final Map<String, DataPermission> mapperMethodIdToDpAnnoMap = new ConcurrentHashMap<>();

    /**
     * Cache of mapperMethodIds known to have no applicable @DataPermission annotation
     * (ConcurrentHashMap cannot hold null values, so negative results are kept separately).
     */
    private final Set<String> mapperMethodIdsWithoutDpAnno = ConcurrentHashMap.newKeySet();

    /**
     * Data permission configuration properties.
     */
    private final DataPermissionProps dpProps;

    /**
     * Data permission type -> supplier of its condition SQL template.
     * Populated only in the constructor and read-only afterwards.
     */
    private final Map<String, Supplier<String>> dpTypeToConditionSqlPropMap = new HashMap<>(8);

    /**
     * Data permission interceptor - constructor.
     *
     * @param dpProps data permission configuration properties (must not be null)
     * @throws NullPointerException if {@code dpProps} is null
     */
    public DataPermissionInterceptor(DataPermissionProps dpProps) {
        this.dpProps = Objects.requireNonNull(dpProps, "dpProps");
        this.init();
    }

    /**
     * Registers the condition SQL supplier of every known data permission type.
     */
    private void init() {
        // Built-in types: method references, so the props getter is evaluated on each use
        dpTypeToConditionSqlPropMap.put(DpTypes.USER, this.dpProps::getConditionForUser);
        dpTypeToConditionSqlPropMap.put(DpTypes.USER_CUSTOM, this.dpProps::getConditionForUserCustom);
        dpTypeToConditionSqlPropMap.put(DpTypes.DEPT, this.dpProps::getConditionForDept);
        dpTypeToConditionSqlPropMap.put(DpTypes.DEPT_AND_CHILD, this.dpProps::getConditionForDeptAndChild);
        dpTypeToConditionSqlPropMap.put(DpTypes.DEPT_CUSTOM, this.dpProps::getConditionForDeptCustom);
        // Other custom types: condition SQL captured as configured at construction time
        if (Objects.nonNull(this.dpProps.getConditionsForOther())) {
            this.dpProps.getConditionsForOther().forEach((dpType, conditionSql) -> dpTypeToConditionSqlPropMap.put(dpType, () -> conditionSql));
        }
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // Mapper method id, format: fully-qualified mapper interface name + "." + method name,
        // e.g. com.luo.dao.BizMapper.insert
        String mapperMethodId = ms.getId();

        // Resolve the annotation first so statements without data permission skip all further work
        DataPermission dpAnno = this.parseMapperMethodDpAnno(mapperMethodId);
        if (Objects.isNull(dpAnno)) {
            return invocation.proceed();
        }

        // The current user's data permission context must have been set
        BaseUserDto dpUser = DpUserContextHolder.getContext();
        if (Objects.isNull(dpUser)) {
            throw new IllegalStateException("无法获取当前用户数据权限信息 - 请先执行DpUserContextHolder.setContext(dpUser)方法");
        }

        // Current SQL command type - UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
        SqlCommandType sqlCommandType = ms.getSqlCommandType();
        // Parameter object passed to the mapper method
        Object paramObjOfMapperMethod = args[1];
        // The 6-arg query overload already carries the BoundSql; otherwise build it from the statement
        BoundSql boundSql = args.length == EXECUTOR_QUERY_CACHE_ARGS_COUNT
                ? (BoundSql) args[BOUND_SQL_ARG_INDEX]
                : ms.getBoundSql(paramObjOfMapperMethod);

        // Original SQL with tabs / newlines flattened to spaces
        String oldSql = BLANK_CHAR_PATTERN.matcher(boundSql.getSql()).replaceAll(SPACE_STRING);
        // New SQL with the data permission condition applied
        String newSqlWithDataPermission = this.fillSqlWithDataPermissionCondition(mapperMethodId, oldSql, sqlCommandType, dpAnno, dpUser);
        log.debug("DataPermissionInterceptor[{}] SQL Before Refactoring: {}", mapperMethodId, oldSql);
        log.debug("DataPermissionInterceptor[{}] SQL After  Refactoring: {}", mapperMethodId, newSqlWithDataPermission);

        this.replaceSql(newSqlWithDataPermission, ms, boundSql, invocation);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object o) {
        // Only Executors (which run insert/update/delete/select) are wrapped
        if (o instanceof Executor) {
            return Plugin.wrap(o, this);
        }
        return o;
    }

    @Override
    public void setProperties(Properties properties) {
        // No properties are read from the MyBatis configuration file
    }

    /**
     * Resolves the @DataPermission annotation of a mapper method (or the matching class-level
     * @DataPermission of its mapper interface), caching both positive and negative results.
     *
     * @param mapperMethodId mapper method id, format: fully-qualified mapper interface name + "." + method name,
     *                       e.g. com.luo.dao.BizMapper.insert
     * @return the @DataPermission annotation for the mapper method, or null if none applies
     * @throws ClassNotFoundException if the namespace part of the id cannot be loaded as a class
     */
    private DataPermission parseMapperMethodDpAnno(String mapperMethodId) throws ClassNotFoundException {
        DataPermission cachedDpAnno = this.mapperMethodIdToDpAnnoMap.get(mapperMethodId);
        if (Objects.nonNull(cachedDpAnno)) {
            return cachedDpAnno;
        }
        if (this.mapperMethodIdsWithoutDpAnno.contains(mapperMethodId)) {
            return null;
        }

        DataPermission dpAnno = this.resolveMapperMethodDpAnno(mapperMethodId);
        if (Objects.isNull(dpAnno)) {
            this.mapperMethodIdsWithoutDpAnno.add(mapperMethodId);
        } else {
            this.mapperMethodIdToDpAnnoMap.put(mapperMethodId, dpAnno);
        }
        return dpAnno;
    }

    /**
     * Uncached reflective lookup of the @DataPermission annotation for a mapper method id.
     *
     * @param mapperMethodId mapper method id
     * @return the method-level annotation, else a class-level annotation whose methodName matches, else null
     * @throws ClassNotFoundException if the namespace part of the id cannot be loaded as a class
     */
    private DataPermission resolveMapperMethodDpAnno(String mapperMethodId) throws ClassNotFoundException {
        int lastDotIndex = mapperMethodId.lastIndexOf('.');
        // An id without a "namespace.method" shape cannot map to a mapper interface method
        if (lastDotIndex <= 0 || lastDotIndex == mapperMethodId.length() - 1) {
            return null;
        }
        // Mapper interface class name
        String mapperClassFromId = mapperMethodId.substring(0, lastDotIndex);
        // Mapper interface method name
        String mapperMethodNameFromId = mapperMethodId.substring(lastDotIndex + 1);

        Class<?> mapperClass = Class.forName(mapperClassFromId);
        // Find the executed mapper method (first public method whose name matches)
        Method mapperMethod = null;
        for (Method method : mapperClass.getMethods()) {
            if (this.matchMapperMethod(method.getName(), mapperMethodNameFromId)) {
                mapperMethod = method;
                break;
            }
        }
        // No matching method: nothing to intercept
        if (Objects.isNull(mapperMethod)) {
            return null;
        }

        // The method-level annotation takes precedence
        DataPermission dpAnnoOfMethod = AnnotatedElementUtils.getMergedAnnotation(mapperMethod, DataPermission.class);
        if (Objects.nonNull(dpAnnoOfMethod)) {
            return dpAnnoOfMethod;
        }

        // Class-level annotations (Repeatable, so several may be present), matched by methodName
        Set<DataPermission> dpAnnoSetOfClass = AnnotatedElementUtils.getMergedRepeatableAnnotations(mapperClass, DataPermission.class);
        for (DataPermission dpAnnoOfClass : dpAnnoSetOfClass) {
            if (this.matchMapperMethod(dpAnnoOfClass.methodName(), mapperMethodNameFromId)) {
                return dpAnnoOfClass;
            }
        }
        return null;
    }

    /**
     * Builds the SQL with the data permission condition of the current user applied.
     * <ol>
     *     <li>User without data permissions: the annotation's defaultAllowAll decides.</li>
     *     <li>User holding the ADMIN type: allow-all condition.</li>
     *     <li>Otherwise: conditions of every supported data permission are OR-ed together;
     *     if none is supported, defaultAllowAll decides.</li>
     * </ol>
     *
     * @param mapperMethodId mapper method id, e.g. com.luo.dao.BizMapper.insert
     * @param oldSql         original SQL
     * @param sqlCommandType SQL command type
     * @param dpAnno         data permission annotation of the mapper method
     * @param dpUser         current user's data permission context
     * @return the rewritten SQL
     * @throws IllegalStateException if a supported data permission type has no configured condition SQL
     */
    private String fillSqlWithDataPermissionCondition(String mapperMethodId, String oldSql, SqlCommandType sqlCommandType, DataPermission dpAnno, BaseUserDto dpUser) {
        // Fallback condition (string form of defaultAllowAll) used when no data permission applies
        String defaultAllowAll = String.valueOf(dpAnno.defaultAllowAll());

        Collection<BaseDpDto> dpCollection = dpUser.getDataPermissions();
        if (Objects.isNull(dpCollection) || dpCollection.isEmpty()) {
            return this.fillSqlWithFinalDpCondition(mapperMethodId, oldSql, sqlCommandType, defaultAllowAll, dpUser);
        }

        String dpCondition = "";
        for (BaseDpDto curDp : dpCollection) {
            String curDpType = curDp.getType();

            // Highest permission - unrestricted access
            if (DpTypes.ADMIN.equals(curDpType)) {
                return this.fillSqlWithFinalDpCondition(mapperMethodId, oldSql, sqlCommandType, SqlConditionUtils.ALLOW_ALL_CONDITION, dpUser);
            }

            // Skip data permissions whose type / operations the mapper method does not support
            if (!this.supportDpTypeAndOperation(dpAnno, curDp, sqlCommandType)) {
                continue;
            }

            String conditionForCurDpType = this.resolveConditionSql(curDpType);
            conditionForCurDpType = SqlParseParamDto.of(dpAnno, dpUser, curDp).fillSqlParams(conditionForCurDpType);
            dpCondition = SqlConditionUtils.appendOrCondition(dpCondition, conditionForCurDpType);
        }

        // dpCondition = ((dp_condition1) or (dp_condition2)); strip the leading OR separator,
        // or fall back to defaultAllowAll when no data permission matched
        boolean hasDataPermission = !dpCondition.isEmpty();
        dpCondition = hasDataPermission
                ? dpCondition.substring(SqlConditionUtils.OR_SEPARATOR.length())
                : defaultAllowAll;

        // TODO restrict executable SQL by commandType - UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH

        return this.fillSqlWithFinalDpCondition(mapperMethodId, oldSql, sqlCommandType, dpCondition, dpUser);
    }

    /**
     * Looks up the condition SQL template configured for a data permission type.
     *
     * @param dpType data permission type
     * @return the configured condition SQL template (as returned by its supplier)
     * @throws IllegalStateException if no condition SQL is registered for {@code dpType}
     */
    private String resolveConditionSql(String dpType) {
        Supplier<String> conditionSqlSupplier = this.dpTypeToConditionSqlPropMap.get(dpType);
        if (Objects.isNull(conditionSqlSupplier)) {
            throw new IllegalStateException("No condition SQL configured for data permission type: " + dpType);
        }
        return conditionSqlSupplier.get();
    }

    /**
     * Applies the final data permission condition to the SQL.
     * <ol>
     *     <li>Deny-all condition on INSERT, or when throwExceptionWhenNoDataPermission is enabled:
     *     throw DataPermissionException.</li>
     *     <li>Replace the placeholder {DATA_PERMISSION_CONDITION} in the SQL with (dpCondition).</li>
     *     <li>Allow-all condition without placeholder: leave the SQL unchanged.</li>
     *     <li>Otherwise append dpCondition to the SQL's WHERE clause.</li>
     * </ol>
     *
     * @param mapperMethodId mapper method id
     * @param oldSql         original SQL
     * @param sqlCommandType SQL command type
     * @param dpCondition    combined data permission condition
     * @param dpUser         current user's data permission context
     * @return the filled SQL
     * @throws DataPermissionException if access is denied as described above
     */
    private String fillSqlWithFinalDpCondition(String mapperMethodId, String oldSql, SqlCommandType sqlCommandType, String dpCondition, BaseUserDto dpUser) {
        // Null-safe read of the flag (works whether the getter returns boolean or Boolean)
        boolean throwWhenNoDp = Boolean.TRUE.equals(this.dpProps.getThrowExceptionWhenNoDataPermission());
        if (SqlConditionUtils.isDenyAllCondition(dpCondition)
                && (SqlCommandType.INSERT.equals(sqlCommandType) || throwWhenNoDp)) {
            throw new DataPermissionException(mapperMethodId, sqlCommandType.name(), dpUser);
        }

        if (oldSql.contains(DATA_PERMISSION_CONDITION_PLACEHOLDER)) {
            return oldSql.replace(DATA_PERMISSION_CONDITION_PLACEHOLDER, SqlConditionUtils.formatBracketConditionWithParams(dpCondition));
        }
        if (SqlConditionUtils.isAllowAllCondition(dpCondition)) {
            return oldSql;
        }
        return SqlConditionUtils.appendWhereAndCondition(oldSql, dpCondition);
    }

    /**
     * Whether the mapper method supports the given data permission's type and operations.
     * <ol>
     *     <li>Empty @DataPermission.supportTypes: every data permission type is supported.</li>
     *     <li>Non-empty @DataPermission.supportTypes: must contain BaseDpDto.type.</li>
     *     <li>Empty BaseDpDto.operations: every SqlCommandType is supported.</li>
     *     <li>Non-empty BaseDpDto.operations: must contain ALL or an operation matching the current SqlCommandType.</li>
     * </ol>
     *
     * @param dpAnno         data permission annotation of the mapper method
     * @param curDpDto       data permission being evaluated
     * @param sqlCommandType SQL command type
     * @return true if both the type and the operation are supported
     */
    private boolean supportDpTypeAndOperation(DataPermission dpAnno, BaseDpDto curDpDto, SqlCommandType sqlCommandType) {
        boolean matchDpType = DpUtils.isEmptyArray(dpAnno.supportTypes())
                || Stream.of(dpAnno.supportTypes()).anyMatch(supportDpType -> supportDpType.equals(curDpDto.getType()));
        boolean matchDpOperation = DpUtils.isEmptyCollection(curDpDto.getOperations())
                || curDpDto.getOperations().stream().anyMatch(operation -> DpOpEnum.ALL.equals(operation) || operation.getSqlCommandType().equals(sqlCommandType.name()));
        return matchDpType && matchDpOperation;
    }

    /**
     * Whether a mapper interface method name matches the method name taken from the intercepted statement id.
     *
     * @param curMethodName          method name declared in the mapper
     * @param mapperMethodNameFromId method name taken from the intercepted statement id
     * @return true on an exact match, or when the id is the PageHelper count query ("_COUNT" suffix) of the method
     */
    private boolean matchMapperMethod(String curMethodName, String mapperMethodNameFromId) {
        return curMethodName.equals(mapperMethodNameFromId)
                || curMethodName.concat(PAGE_HELPER_COUNT_SUFFIX).equals(mapperMethodNameFromId);
    }

    /**
     * Replaces the SQL of the intercepted invocation by swapping in a new MappedStatement
     * (and, for the 6-arg query overload, a new BoundSql and CacheKey).
     *
     * @param newSql     new SQL statement
     * @param ms         original MappedStatement
     * @param boundSql   original BoundSql
     * @param invocation intercepted invocation whose arguments are replaced in place
     */
    private void replaceSql(String newSql, MappedStatement ms, BoundSql boundSql, Invocation invocation) {
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql, boundSql.getParameterMappings(), boundSql.getParameterObject());
        // Carry over additional parameters (e.g. produced by <foreach>/<bind>) referenced by the mappings
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }

        MappedStatement newMs = this.newMappedStatement(ms, parameterObject -> newBoundSql);
        Object[] args = invocation.getArgs();
        args[0] = newMs;

        if (args.length == EXECUTOR_QUERY_CACHE_ARGS_COUNT) {
            // The incoming CacheKey was presumably built by the caller from the original BoundSql, i.e.
            // without the user-specific data permission condition. Rebuild it from the rewritten SQL so
            // cached results are not shared between statements differing only in that condition.
            Executor executor = (Executor) invocation.getTarget();
            args[CACHE_KEY_ARG_INDEX] = executor.createCacheKey(newMs, args[1], (RowBounds) args[ROW_BOUNDS_ARG_INDEX], newBoundSql);
            args[BOUND_SQL_ARG_INDEX] = newBoundSql;
        }
    }

    /**
     * Creates a copy of a MappedStatement that uses a different SqlSource.
     *
     * @param ms           original MappedStatement
     * @param newSqlSource new SqlSource
     * @return the new MappedStatement
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // Builder.keyProperty/keyColumn/resultSets split their argument on ',', so join to keep every entry
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultOrdered(ms.isResultOrdered());
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }
}